{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cell_id": "0078a1421f5a4d3f97cc22babb429849",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 61
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 1608,
    "execution_start": 1672343728631,
    "id": "DD9A7AB629ED4CC89D8FB81DD47E8CFC",
    "jupyter": {},
    "notebookId": "6069a682a7fbf6001890aecf",
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "source_hash": "995dab41",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 891 entries, 0 to 890\n",
      "Data columns (total 12 columns):\n",
      " #   Column       Non-Null Count  Dtype  \n",
      "---  ------       --------------  -----  \n",
      " 0   PassengerId  891 non-null    int64  \n",
      " 1   Survived     891 non-null    int64  \n",
      " 2   Pclass       891 non-null    int64  \n",
      " 3   Name         891 non-null    object \n",
      " 4   Sex          891 non-null    object \n",
      " 5   Age          714 non-null    float64\n",
      " 6   SibSp        891 non-null    int64  \n",
      " 7   Parch        891 non-null    int64  \n",
      " 8   Ticket       891 non-null    object \n",
      " 9   Fare         891 non-null    float64\n",
      " 10  Cabin        204 non-null    object \n",
      " 11  Embarked     889 non-null    object \n",
      "dtypes: float64(2), int64(5), object(5)\n",
      "memory usage: 83.7+ KB\n",
      "None\n",
      "   PassengerId  Survived  Pclass  \\\n",
      "0            1         0       3   \n",
      "1            2         1       1   \n",
      "2            3         1       3   \n",
      "3            4         1       1   \n",
      "4            5         0       3   \n",
      "\n",
      "                                                Name     Sex   Age  SibSp  \\\n",
      "0                            Braund, Mr. Owen Harris    male  22.0      1   \n",
      "1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   \n",
      "2                             Heikkinen, Miss. Laina  female  26.0      0   \n",
      "3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   \n",
      "4                           Allen, Mr. William Henry    male  35.0      0   \n",
      "\n",
      "   Parch            Ticket     Fare Cabin Embarked  \n",
      "0      0         A/5 21171   7.2500   NaN        S  \n",
      "1      0          PC 17599  71.2833   C85        C  \n",
      "2      0  STON/O2. 3101282   7.9250   NaN        S  \n",
      "3      0            113803  53.1000  C123        S  \n",
      "4      0            373450   8.0500   NaN        S  \n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "# 读取数据\n",
    "data = pd.read_csv('titanic/train.csv')\n",
    "# 查看数据集信息和前5行具体内容，其中NaN代表数据缺失\n",
    "print(data.info())\n",
    "print(data[:5])\n",
    "\n",
    "# 删去编号、姓名、船票编号3列\n",
    "data.drop(columns=['PassengerId', 'Name', 'Ticket'], inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "cell_id": "094db09de1cf4921ab7d3ca98100920c",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 73
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 3,
    "execution_start": 1672344091744,
    "source_hash": "d5cc24ad",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Age ：\n",
      "0.4200\n",
      "9.2622\n",
      "18.1044\n",
      "26.9467\n",
      "35.7889\n",
      "44.6311\n",
      "53.4733\n",
      "62.3156\n",
      "71.1578\n",
      "80.0000\n",
      "Fare ：\n",
      "0.0000\n",
      "56.9255\n",
      "113.8509\n",
      "170.7764\n",
      "227.7019\n",
      "284.6273\n",
      "341.5528\n",
      "398.4783\n",
      "455.4037\n",
      "512.3292\n"
     ]
    }
   ],
   "source": [
    "feat_ranges = {}\n",
    "cont_feat = ['Age', 'Fare'] # 连续特征\n",
    "bins = 10 # 分类点数\n",
    "\n",
    "for feat in cont_feat:\n",
    "    # 数据集中存在缺省值nan，需要用np.nanmin和np.nanmax\n",
    "    min_val = np.nanmin(data[feat]) \n",
    "    max_val = np.nanmax(data[feat])\n",
    "    feat_ranges[feat] = np.linspace(min_val, max_val, bins).tolist()\n",
    "    print(feat, '：') # 查看分类点\n",
    "    for spt in feat_ranges[feat]:\n",
    "        print(f'{spt:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "cell_id": "a894a8ad900048948c1f0bc8e39bff83",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 85
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 34,
    "execution_start": 1671093938393,
    "source_hash": "8d26937d",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sex：Index(['female', 'male'], dtype='object')\n",
      "Pclass：Int64Index([1, 2, 3], dtype='int64')\n",
      "SibSp：Int64Index([0, 1, 2, 3, 4, 5, 8], dtype='int64')\n",
      "Parch：Int64Index([0, 1, 2, 3, 4, 5, 6], dtype='int64')\n",
      "Cabin：Index(['A10', 'A14', 'A16', 'A19', 'A20', 'A23', 'A24', 'A26', 'A31', 'A32',\n",
      "       ...\n",
      "       'E8', 'F E69', 'F G63', 'F G73', 'F2', 'F33', 'F38', 'F4', 'G6', 'T'],\n",
      "      dtype='object', length=147)\n",
      "Embarked：Index(['C', 'Q', 'S'], dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# 只有有限取值的离散特征\n",
    "cat_feat = ['Sex', 'Pclass', 'SibSp', 'Parch', 'Cabin', 'Embarked'] \n",
    "for feat in cat_feat:\n",
    "    data[feat] = data[feat].astype('category') # 数据格式转为分类格式\n",
    "    print(f'{feat}：{data[feat].cat.categories}') # 查看类别\n",
    "    data[feat] = data[feat].cat.codes.to_list() # 将类别按顺序转换为整数\n",
    "    ranges = list(set(data[feat]))\n",
    "    ranges.sort()\n",
    "    feat_ranges[feat] = ranges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "cell_id": "91df2b70b61a4048abc900755337f060",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 97
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 1,
    "execution_start": 1671093938430,
    "source_hash": "5b278742",
    "tags": []
   },
   "outputs": [],
   "source": [
    "# 将所有缺省值替换为-1\n",
    "data.fillna(-1, inplace=True)\n",
    "for feat in feat_ranges.keys():\n",
    "    feat_ranges[feat] = [-1] + feat_ranges[feat]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "cell_id": "b45db4d7938b4688a6c6661a16f0f71d",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 109
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 78,
    "execution_start": 1671093938435,
    "source_hash": "d1aadcb7",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集大小： 712\n",
      "测试集大小： 179\n",
      "特征数： 8\n"
     ]
    }
   ],
   "source": [
    "# 划分训练集与测试集\n",
    "np.random.seed(0)\n",
    "feat_names = data.columns[1:]\n",
    "label_name = data.columns[0]\n",
    "# 重排下标之后，按新的下标索引数据\n",
    "data = data.reindex(np.random.permutation(data.index))\n",
    "ratio = 0.8\n",
    "split = int(ratio * len(data))\n",
    "train_x = data[:split].drop(columns=['Survived']).to_numpy()\n",
    "train_y = data['Survived'][:split].to_numpy()\n",
    "test_x = data[split:].drop(columns=['Survived']).to_numpy()\n",
    "test_y = data['Survived'][split:].to_numpy()\n",
    "print('训练集大小：', len(train_x))\n",
    "print('测试集大小：', len(test_x))\n",
    "print('特征数：', train_x.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "cell_id": "920320cfbb5940f8b3587fc4b25d18bc",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 121
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 1,
    "execution_start": 1671093938509,
    "source_hash": "3a580bf",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class Node:\n",
    "\n",
    "    def __init__(self):\n",
    "        # 内部结点的feat表示用来分类的特征编号，其数字与数据中的顺序对应\n",
    "        # 叶结点的feat表示该结点对应的分类结果\n",
    "        self.feat = None\n",
    "        # 分类值列表，表示按照其中的值向子结点分类\n",
    "        self.split = None\n",
    "        # 子结点列表，叶结点的child为空\n",
    "        self.child = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "cell_id": "f0a03026c20d402981c03219041cd932",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 133
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 40,
    "execution_start": 1671093938510,
    "source_hash": "18107b07",
    "tags": []
   },
   "outputs": [],
   "source": [
    "class DecisionTree:\n",
    "\n",
    "    def __init__(self, X, Y, feat_ranges, lbd):\n",
    "        self.root = Node()\n",
    "        self.X = X\n",
    "        self.Y = Y\n",
    "        self.feat_ranges = feat_ranges # 特征取值范围\n",
    "        self.lbd = lbd # 正则化系数\n",
    "        self.eps = 1e-8 # 防止数学错误log(0)和除以0\n",
    "        self.T = 0 # 记录叶结点个数\n",
    "        self.ID3(self.root, self.X, self.Y)\n",
    "\n",
    "    # 工具函数，计算 a * log a\n",
    "    def aloga(self, a):\n",
    "        return a * np.log2(a + self.eps)\n",
    "\n",
    "    # 计算某个子数据集的熵\n",
    "    def entropy(self, Y):\n",
    "        cnt = np.unique(Y, return_counts=True)[1] # 统计每个类别出现的次数\n",
    "        N = len(Y)\n",
    "        ent = -np.sum([self.aloga(Ni / N) for Ni in cnt])\n",
    "        return ent\n",
    "\n",
    "    # 计算用feat <= val划分数据集的信息增益\n",
    "    def info_gain(self, X, Y, feat, val):\n",
    "        # 划分前的熵\n",
    "        N = len(Y)\n",
    "        if N == 0:\n",
    "            return 0\n",
    "        HX = self.entropy(Y)\n",
    "        HXY = 0 # H(X|Y)\n",
    "        # 分别计算H(X|X_F<=val)和H(X|X_F>val)\n",
    "        Y_l = Y[X[:, feat] <= val]\n",
    "        HXY += len(Y_l) / len(Y) * self.entropy(Y_l)\n",
    "        Y_r = Y[X[:, feat] > val]\n",
    "        HXY += len(Y_r) / len(Y) * self.entropy(Y_r)\n",
    "        return HX - HXY\n",
    "\n",
    "    # 计算特征feat <= val本身的复杂度H_Y(X)\n",
    "    def entropy_YX(self, X, Y, feat, val):\n",
    "        HYX = 0\n",
    "        N = len(Y)\n",
    "        if N == 0:\n",
    "            return 0\n",
    "        Y_l = Y[X[:, feat] <= val]\n",
    "        HYX += -self.aloga(len(Y_l) / N)\n",
    "        Y_r = Y[X[:, feat] > val]\n",
    "        HYX += -self.aloga(len(Y_r) / N)\n",
    "        return HYX\n",
    "\n",
    "    # 计算用feat <= val划分数据集的信息增益率\n",
    "    def info_gain_ratio(self, X, Y, feat, val):\n",
    "        IG = self.info_gain(X, Y, feat, val)\n",
    "        HYX = self.entropy_YX(X, Y, feat, val)\n",
    "        return IG / HYX\n",
    "\n",
    "    # 用ID3算法递归分裂结点，构造决策树\n",
    "    def ID3(self, node, X, Y):\n",
    "        # 判断是否已经分类完成\n",
    "        if len(np.unique(Y)) == 1:\n",
    "            node.feat = Y[0]\n",
    "            self.T += 1\n",
    "            return\n",
    "        \n",
    "        # 寻找最优分类特征和分类点\n",
    "        best_IGR = 0\n",
    "        best_feat = None\n",
    "        best_val = None\n",
    "        for feat in range(len(feat_names)):\n",
    "            for val in self.feat_ranges[feat_names[feat]]:\n",
    "                IGR = self.info_gain_ratio(X, Y, feat, val)\n",
    "                if IGR > best_IGR:\n",
    "                    best_IGR = IGR\n",
    "                    best_feat = feat\n",
    "                    best_val = val\n",
    "        \n",
    "        # 计算用best_feat <= best_val分类带来的代价函数变化\n",
    "        # 由于分裂叶结点只涉及该局部，我们只需要计算分裂前后该结点的代价函数\n",
    "        # 当前代价\n",
    "        cur_cost = len(Y) * self.entropy(Y) + self.lbd\n",
    "        # 分裂后的代价，按best_feat的取值分类统计\n",
    "        # 如果best_feat为None，说明最优的信息增益率为0，\n",
    "        # 再分类也无法增加信息了，因此将new_cost设置为无穷大\n",
    "        if best_feat is None:\n",
    "            new_cost = np.inf\n",
    "        else:\n",
    "            new_cost = 0\n",
    "            X_feat = X[:, best_feat]\n",
    "            # 获取划分后的两部分，计算新的熵\n",
    "            new_Y_l = Y[X_feat <= best_val]\n",
    "            new_cost += len(new_Y_l) * self.entropy(new_Y_l)\n",
    "            new_Y_r = Y[X_feat > best_val]\n",
    "            new_cost += len(new_Y_r) * self.entropy(new_Y_r)\n",
    "            # 分裂后会有两个叶结点\n",
    "            new_cost += 2 * self.lbd\n",
    "\n",
    "        if new_cost <= cur_cost:\n",
    "            # 如果分裂后代价更小，那么执行分裂\n",
    "            node.feat = best_feat\n",
    "            node.split = best_val\n",
    "            l_child = Node()\n",
    "            l_X = X[X_feat <= best_val]\n",
    "            l_Y = Y[X_feat <= best_val]\n",
    "            self.ID3(l_child, l_X, l_Y)\n",
    "            r_child = Node()\n",
    "            r_X = X[X_feat > best_val]\n",
    "            r_Y = Y[X_feat > best_val]\n",
    "            self.ID3(r_child, r_X, r_Y)\n",
    "            node.child = [l_child, r_child]\n",
    "        else:\n",
    "            # 否则将当前结点上最多的类别作为该结点的类别\n",
    "            vals, cnt = np.unique(Y, return_counts=True)\n",
    "            node.feat = vals[np.argmax(cnt)]\n",
    "            self.T += 1\n",
    "\n",
    "    # 预测新样本的分类\n",
    "    def predict(self, x):\n",
    "        node = self.root\n",
    "        # 从根结点开始向下寻找，到叶结点结束\n",
    "        while node.split is not None:\n",
    "            # 判断x应该处于哪个子结点\n",
    "            if x[node.feat] <= node.split:\n",
    "                node = node.child[0]\n",
    "            else:\n",
    "                node = node.child[1]\n",
    "        # 到达叶结点，返回类别\n",
    "        return node.feat\n",
    "\n",
    "    # 计算在样本X，标签Y上的准确率\n",
    "    def accuracy(self, X, Y):\n",
    "        correct = 0\n",
    "        for x, y in zip(X, Y):\n",
    "            pred = self.predict(x)\n",
    "            if pred == y:\n",
    "                correct += 1\n",
    "        return correct / len(Y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "cell_id": "d9f9571da56c4ada86801cb73922afab",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 145
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 970,
    "execution_start": 1671093938550,
    "source_hash": "77088117",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "叶结点数量： 23\n",
      "训练集准确率： 0.8300561797752809\n",
      "测试集准确率： 0.7262569832402235\n"
     ]
    }
   ],
   "source": [
    "DT = DecisionTree(train_x, train_y, feat_ranges, lbd=1.0)\n",
    "print('叶结点数量：', DT.T)\n",
    "\n",
    "# 计算在训练集和测试集上的准确率\n",
    "print('训练集准确率：', DT.accuracy(train_x, train_y))\n",
    "print('测试集准确率：', DT.accuracy(test_x, test_y))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "cell_id": "5ddbd737be6d4714a801dc9662c0428a",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 157
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 1085,
    "execution_start": 1671093939524,
    "id": "9B8D853EBB02453FA927D4E75A2F8058",
    "jupyter": {},
    "notebookId": "6069a682a7fbf6001890aecf",
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "source_hash": "3ad7b7c5",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练集准确率：C4.5：0.8792134831460674，CART：0.8848314606741573\n",
      "测试集准确率：C4.5：0.7150837988826816，CART：0.7877094972067039\n"
     ]
    }
   ],
   "source": [
    "from sklearn import tree\n",
    "\n",
    "# criterion表示分类依据，max_depth表示树的最大深度\n",
    "# entropy生成的是C4.5分类树\n",
    "c45 = tree.DecisionTreeClassifier(criterion='entropy', max_depth=6)\n",
    "c45.fit(train_x, train_y)\n",
    "# gini生成的是CART分类树\n",
    "cart = tree.DecisionTreeClassifier(criterion='gini', max_depth=6)\n",
    "cart.fit(train_x, train_y)\n",
    "\n",
    "c45_train_pred = c45.predict(train_x)\n",
    "c45_test_pred = c45.predict(test_x)\n",
    "cart_train_pred = cart.predict(train_x)\n",
    "cart_test_pred = cart.predict(test_x)\n",
    "print(f'训练集准确率：C4.5：{np.mean(c45_train_pred == train_y)}，' \\\n",
    "    f'CART：{np.mean(cart_train_pred == train_y)}')\n",
    "print(f'测试集准确率：C4.5：{np.mean(c45_test_pred == test_y)}，' \\\n",
    "    f'CART：{np.mean(cart_test_pred == test_y)}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "cell_id": "30c23d316f0b4744a68395616dd32bfc",
    "deepnote_app_coordinates": {
     "h": 5,
     "w": 12,
     "x": 0,
     "y": 169
    },
    "deepnote_cell_type": "code",
    "deepnote_to_be_reexecuted": false,
    "execution_millis": 5575,
    "execution_start": 1671095780757,
    "id": "9304CE0BF1DD4C2B869A4CEBBBCC0D48",
    "jupyter": {},
    "notebookId": "6069a682a7fbf6001890aecf",
    "scrolled": false,
    "slideshow": {
     "slide_type": "slide"
    },
    "source_hash": "95e93f04",
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pydotplus in f:\\anaconda3\\lib\\site-packages (2.0.2)\n",
      "Requirement already satisfied: pyparsing>=2.0.1 in f:\\anaconda3\\lib\\site-packages (from pydotplus) (3.0.9)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "!pip install pydotplus\n",
    "\n",
    "from six import StringIO\n",
    "import pydotplus\n",
    "\n",
    "dot_data = StringIO()\n",
    "tree.export_graphviz( # 导出sklearn的决策树的可视化数据\n",
    "    c45,\n",
    "    out_file=dot_data,\n",
    "    feature_names=feat_names,\n",
    "    class_names=['non-survival', 'survival'],\n",
    "    filled=True, \n",
    "    rounded=True,\n",
    "    impurity=False\n",
    ")\n",
    "# 用pydotplus生成图像\n",
    "graph = pydotplus.graph_from_dot_data(\n",
    "    dot_data.getvalue().replace('\\n', '')) \n",
    "graph.write_png('tree.png')"
   ]
  }
 ],
 "metadata": {
  "deepnote": {},
  "deepnote_app_layout": "article",
  "deepnote_execution_queue": [],
  "deepnote_notebook_id": "7bfdffc32eae4d6c823e7278fe118e55",
  "deepnote_persisted_session": {
   "createdAt": "2022-12-29T20:30:11.380Z"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
